Home > NoiseTools > nt_regw (original).m

nt_regw (original)

PURPOSE ^

[b,z]=nt_regw(y,x,w) - weighted regression

SYNOPSIS ^

function [b,z]=nt_regw(y,x,w)

DESCRIPTION ^

[b,z]=nt_regw(y,x,w) - weighted regression

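  b: regression matrix (coefficients of the fit of y from x)
  z: regression (fit of y from x, with the weighted mean of y added back)

  y: data
  x: regressor
  w: weight to apply to y (optional; if omitted or empty, the regression is unweighted)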

  w is either a matrix of same size as y, or a column vector to be applied
  to each column of y

 NoiseTools

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 function [b,z]=nt_regw(y,x,w)
0002 %[b,z]=nt_regw(y,x,w) - weighted regression
0003 %
0004 %  b: regression matrix, in the basis of the retained principal components
0005 %     of x (columns = channels of y; rows = channels if w is a matrix)
0006 %  z: regression (fit of y from x, with the weighted mean of y added back)
0007 %
0008 %  y: data
0009 %  x: regressor
0010 %  w: weight to apply to y (optional; omitted or [] means unweighted)
0011 %
0012 %  w is either a matrix of same size as y, or a column vector to be applied
0013 %  to each column of y
0014 %
0015 % NoiseTools
0016 PCA_THRESH=0.0000001; % discard dimensions of x with eigenvalue lower than this
0017 
0018 if nargin<3; w=[]; end
0019 if nargin<2; error('!'); end
0020 
0021 %% check/fix sizes
0022 m=size(y,1); % row count before unfolding, passed to nt_fold at the end
0023 x=nt_unfold(x);
0024 y=nt_unfold(y);
0025 if size(x,1)~=size(y,1); 
0026     disp(size(x)); disp(size(y)); error('!'); 
0027 end
0028 
0029 %% save weighted mean (added back to the fit z below)
0030 if nargout>1
0031     mn=y-nt_demean(y,w);
0032 end
0033 
0034 %%
0035 if isempty(w) 
0036     %% simple regression
0037     xx=nt_demean(x);
0038     yy=nt_demean(y);
0039     [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D)); % eigenvalues as a vector, as in the weighted branches
0040     topcs=V(:,find(D/max(D) > PCA_THRESH)); % discard weak dims
0041     xxx=xx*topcs;
0042     b=(yy'*xxx) / (xxx'*xxx); b=b';
0043     if nargout>1; z=nt_demean(x,w)*topcs*b; z=z+mn; end
0044 else
0045     %% weighted regression
0046     if size(w,1)~=size(x,1); error('!'); end
0047     if size(w,2)==1; 
0048         %% same weight for all channels
0049         if sum(w(:))==0; 
0050             error('weights all zero');
0051         end
0052         yy=nt_demean(y,w).*repmat(w,1,size(y,2)); 
0053         xx=nt_demean(x,w).*repmat(w,1,size(x,2));  
0054         [V,D]=eig(xx'*xx); V=real(V); D=real(D); D=diag(D);
0055         topcs=V(:,find(D/max(D) > PCA_THRESH)); % discard weak dims
0056         xxx=xx*topcs;
0057         b=(yy'*xxx) / (xxx'*xxx); b=b';
0058         if nargout>1; z=nt_demean(x,w)*topcs*b; z=z+mn; end
0059     else
0060         %% each channel has own weight
0061         if size(w,2) ~= size(y,2); error('!'); end 
0062         if nargout>1; z=zeros(size(y)); end
0063         for iChan=1:size(y,2)
0064             if sum(w(:,iChan))==0; %disp(iChan);
0065                 error('weights all zero'); 
0066             end
0067             yy=nt_demean(y(:,iChan),w(:,iChan)) .* w(:,iChan); 
0068             xc=nt_demean(x,w(:,iChan)); % remove channel-specific-weighted mean from regressor (x itself is left untouched)
0069             xx=xc.*repmat(w(:,iChan),1,size(x,2)); 
0070             [V,D]=eig(xx'*xx); V=real(V); D=real(diag(D));
0071             topcs=V(:,find(D/max(D) > PCA_THRESH)); % discard weak dims
0072             xxx=xx*topcs;
0073             b(iChan,1:size(topcs,2))=(yy'*xxx) / (xxx'*xxx); 
0074             % fit for this channel, with its weighted mean added back
0075             if nargout>1; z(:,iChan)=xc*(topcs*b(iChan,1:size(topcs,2))') + mn(:,iChan); end
0076         end
0077     end             
0078 end
0079 
0080 %% restore the original layout of the fit (m = row count of y before unfolding)
0081 if nargout>1;
0082     z=nt_fold(z,m);
0083 end
0084 
0085 %% test code (disabled: each 'if 0' block is never executed, copy and run manually)
0086 if 0
0087     % basic: regression matrices for random data, without weights and with a weight matrix of ones
0088     x=randn(100,10); y=randn(100,10); 
0089     b1=nt_regw(x,y); b2=nt_regw(x,x); b3=nt_regw(x,y,ones(size(x))); 
0090     figure(1); subplot 131; nt_imagescc(b1); subplot 132; nt_imagescc(b2); subplot 133; nt_imagescc(b3);
0091 end
0092 if 0
0093     % fit random walk with a cubic polynomial of the sample index (no weights)
0094     y=cumsum(randn(1000,1)); x=(1:1000)'; x=[x,x.^2,x.^3];
0095     [b,z]=nt_regw(y,x); 
0096     figure(1); clf; plot([y,z]);
0097 end
0098 if 0
0099     % weights, random: same fit with a column vector of random weights
0100     y=cumsum(randn(1000,1)); x=(1:1000)'; x=[x,x.^2,x.^3];
0101     w=rand(size(y));
0102     [b,z]=nt_regw(y,x,w); 
0103     figure(1); clf; plot([y,z]);
0104 end
0105 if 0
0106     % weights, 1st vs 2nd half: first 500 samples get weight 0, the rest weight 1
0107     y=cumsum(randn(1000,1))+1000; x=(1:1000)'; x=[x,x.^2,x.^3];
0108     w=ones(size(y)); w(1:500,:)=0;
0109     [b,z]=nt_regw(y,x,w); 
0110     figure(1); clf; plot([y,z]);
0111 end
0112 if 0
0113     % multichannel: two data channels with a weight matrix of ones (one weight column per channel)
0114     y=cumsum(randn(1000,2)); x=(1:1000)'; x=[x,x.^2,x.^3];
0115     w=ones(size(y)); 
0116     [b,z]=nt_regw(y,x,w); 
0117     figure(1); clf; plot([y,z]);
0118 end
0119

Generated on Sat 29-Apr-2023 17:15:46 by m2html © 2005